{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0ead654a",
   "metadata": {},
   "source": [
    "# ReRank"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99f6a6f7",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Load the needed API keys and relevant Python libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f350cd1b",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# !pip install cohere \n",
    "# !pip install weaviate-client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2febbb9-27dd-4209-838a-99b4f9cdf51b",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab2ecba-3403-4317-86ef-bd6d92a6cb46",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import cohere\n",
    "co = cohere.Client(os.environ['COHERE_API_KEY'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30737b1b-e4c8-4bd0-a04b-c2ce70d28821",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "import weaviate\n",
    "auth_config = weaviate.auth.AuthApiKey(\n",
    "    api_key=os.environ['WEAVIATE_API_KEY'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8781f638-17c7-4ab7-86b5-3763d4d5abad",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "# Connect to the Weaviate instance whose URL is read from WEAVIATE_API_URL,\n",
    "# authenticating with `auth_config`; the Cohere API key is sent along as an\n",
    "# additional request header.\n",
    "client = weaviate.Client(\n",
    "    url=os.environ['WEAVIATE_API_URL'],\n",
    "    auth_client_secret=auth_config,\n",
    "    additional_headers={\n",
    "        \"X-Cohere-Api-Key\": os.environ['COHERE_API_KEY'],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffcc8e5e",
   "metadata": {},
   "source": [
    "## Dense Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8561fbf-035e-4856-a97f-8eda21d32a81",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from utils import dense_retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15694a5c-3525-49cc-b5e9-d1c34ae0fbe9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "query = \"What is the capital of Canada?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfede25-8a43-41c9-9328-d331695c4fcb",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "dense_retrieval_results = dense_retrieval(query, client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1822cc6c-ddc2-4938-b746-7cda2506d51e",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from utils import print_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2990c5c4-1b63-453e-8dd8-8568cb7872f5",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "print_result(dense_retrieval_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db449134",
   "metadata": {},
   "source": [
    "## Improving Keyword Search with ReRank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8071c68a-6dec-47f9-b5e1-473f9acdc83f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from utils import keyword_search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa47e41a-5988-4405-af7b-c7cb7382eed9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "query_1 = \"What is the capital of Canada?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e851efa5-10c7-4f98-85f1-2a1c565d9723",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "# Run a keyword search for the query, asking for 3 results.\n",
    "query_1 = \"What is the capital of Canada?\"\n",
    "search_properties = [\"text\", \"title\", \"url\", \"views\", \"lang\", \"_additional {distance}\"]\n",
    "results = keyword_search(query_1, client, properties=search_properties, num_results=3)\n",
    "\n",
    "# Print the position, title and text of each hit.\n",
    "for i, result in enumerate(results):\n",
    "    print(f\"i:{i}\")\n",
    "    for field in ('title', 'text'):\n",
    "        print(result.get(field))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e1b2d2c",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "# Repeat the keyword search for the same query, this time requesting 500 results\n",
    "# so that there is a large candidate pool in `results`.\n",
    "query_1 = \"What is the capital of Canada?\"\n",
    "results = keyword_search(query_1,\n",
    "                         client,\n",
    "                         properties=[\"text\", \"title\", \"url\", \"views\", \"lang\", \"_additional {distance}\"],\n",
    "                         num_results=500\n",
    "                        )\n",
    "\n",
    "# Only the position and title of each hit are printed.\n",
    "for i, result in enumerate(results):\n",
    "    print(f\"i:{i}\")\n",
    "    print(result.get('title'))\n",
    "    # print(result.get('text'))  # uncomment to also print each hit's text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b38761f8-32b1-4b44-be97-0884894cf6b3",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "def rerank_responses(query, responses, num_responses=10,\n",
    "                     model='rerank-english-v2.0'):\n",
    "    \"\"\"Rerank candidate documents against a query via `co.rerank`.\n",
    "\n",
    "    Args:\n",
    "        query: The search query each candidate is scored against.\n",
    "        responses: The candidate documents, passed to `co.rerank` as `documents`.\n",
    "        num_responses: Number of results to request (`top_n`).\n",
    "        model: Name of the rerank model. Defaults to the value that used to be\n",
    "            hard-coded, so existing calls behave exactly as before.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `co.rerank` returns, unmodified.\n",
    "    \"\"\"\n",
    "    # `co` is the Cohere client created in the setup cells of this notebook.\n",
    "    return co.rerank(\n",
    "        model=model,\n",
    "        query=query,\n",
    "        documents=responses,\n",
    "        top_n=num_responses,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d3e55c-0a5b-4b3a-9a59-3f7164927dc0",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Pull the 'text' field out of every search hit, then rerank those texts\n",
    "# against the first query.\n",
    "texts = [hit.get('text') for hit in results]\n",
    "reranked_text = rerank_responses(query_1, texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3a380b-cebf-47da-956d-dc62dc53e5a0",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "# Print each reranked entry under its position, followed by a blank line.\n",
    "for i, rerank_result in enumerate(reranked_text):\n",
    "    print(f\"i:{i}\")\n",
    "    print(f\"{rerank_result}\", end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6cbb081",
   "metadata": {},
   "source": [
    "## Improving Dense Retrieval with ReRank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be2e5378-ea37-4726-b3c3-5875d46759e7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from utils import dense_retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5af11ea-6c30-4303-8c9e-8a5510e046bb",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "query_2 = \"Who is the tallest person in history?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da5c744-01b8-4780-a615-0a5edf9bfbd6",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "results = dense_retrieval(query_2,client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e4540d8-ed5e-4f97-8802-6d39a52b8964",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "# For every hit in `results`: position, title, text, then a blank line.\n",
    "for i, result in enumerate(results):\n",
    "    print(f\"i:{i}\")\n",
    "    for field in ('title', 'text'):\n",
    "        print(result.get(field))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d269db28-15aa-426a-a993-14275a36ca09",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Pull the 'text' field out of every hit in `results`, then rerank those\n",
    "# texts against the second query.\n",
    "texts = [hit.get('text') for hit in results]\n",
    "reranked_text = rerank_responses(query_2, texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7aca9b-bdc0-4c08-9615-1a7408854cb4",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "# Print each reranked entry under its position, followed by a blank line.\n",
    "for i, rerank_result in enumerate(reranked_text):\n",
    "    print(f\"i:{i}\")\n",
    "    print(f\"{rerank_result}\", end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1a763a-1b1d-4d4a-99b4-26ab341663e9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9000db7-3202-436f-a542-ae20b8879ea7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08880d64-a78e-4871-bae3-e75ab88ac3ad",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd4bb517-e5a7-4a4f-bc2b-cb4ea2fad2bc",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c297e7ba-5a95-412e-9a4b-b5f214570cfe",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93763474-ea16-4a2e-a0ff-08c088f8c708",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8cbf0cd-1150-47d4-a33b-966267960dfc",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f62842-d1e0-4388-8a1d-43e5cb1e6d05",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac7843fb-2d5a-49ed-9520-fd671333ee0c",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e353066b-7c07-42ad-bc75-ac56fc1a25b2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bde2f0f-5d0a-4b3f-8739-d1d0ae8fadab",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9a7eef-980b-4591-8c7c-1505a33a6f95",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4da8462-b32c-4aeb-8fe4-2370e499e7a1",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a0466e-ffae-4c16-988c-862290a7b604",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8233d2cf-0cfd-4f59-98dc-838708dd8452",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec3aead-d23b-47e1-8ec4-7c585f91a960",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "317bafdc-18e9-4c80-8098-79a35c83eb1f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5707147-2c0c-4055-a179-e653ad9533c9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c633a4-daea-4d8a-983d-807fc612874b",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
